[ExecuTorch][WebGPU] Dynamic resize hook for SDPA (live seq-len S)#20580
[ExecuTorch][WebGPU] Dynamic resize hook for SDPA (live seq-len S)#20580JulianCloudNTH wants to merge 2 commits into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20580
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 2 PendingAs of commit 6ee7775 with merge base 0cef6de ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude review and check for any areas or opportunities for modularization |
|
Claude finished @JulianCloudNTH's task in 2m 39s —— View job Review: Dynamic resize hook for SDPA (live seq-len S)
The change is well-conceived: keying the recompute on the q tensor (always) and Modularization (the main ask)1. The hook re-implements the five-dispatch build, line for line. The biggest win here. Dispatches 1–5 ( Extracting a single struct + function that maps struct SdpaStepCounts {
UpdateCacheParams uc;
AttnWeightsParams qk;
SoftmaxParams softmax;
ComputeOutParams av;
uint32_t uc_wgc, qk_wgc, softmax_wgc, av_wgc;
};
// Computes every per-step quantity from the live (s, pos); shared by the
// initial build and the resize hook so the two can never drift.
static SdpaStepCounts compute_sdpa_step(
WGPUDevice device, int64_t s, int64_t pos, const SdpaDims& d,
uint32_t uc_wg, uint32_t qk_wg, uint32_t av_wg);The build path calls it once to populate the initial UBOs/ 2. The 24-entry lambda capture list ( 3. A tile-grid helper. CorrectnessScratch-capacity guard is implicit. The scratch buffers are sized at build for Redundant double-fire when both triggers are dirty. With Minor
Overall this is a clean extension of the dynamic path and the Vulkan |
Stack from ghstack (oldest at bottom):
Make
sdpa_with_kv_cacheserve any live seq-len S from one graph (batched prefill S=K and decode S=1).Problem: the existing dynamic path only reacted to a live
input_pos(decode), with S captured at build time. It rewrote the QK dispatch (which depends oncontext_len) but leftupdate_cache, softmax, and AV sized for the build-time S. Under a dynamic seq-len S (one graph serving prefill and decode),kv_numel, the QK/AV tile grids, and the softmax row count all depend on S and were stale.Solution: a single recompute hook driven by either a live S (q tensor resize) or a live
input_pos(SymInt), recomputing every per-step quantity from the live shape.input_pos; recomputes ctx + QK count; S fixed.input_pos(when SymInt); reads live S fromcur_dims(q)and live pos, recomputes all five dispatches' counts + UBOs (update_cacheK/V, QK, softmax, AV), and sets the outputcur_dimsto q's.Implementation:
update_cache/softmax/AV dispatch indices (previously only QK) so their workgroup counts can be rewritten per step.Hq*ceil(S/TM)*ceil(ctx-or-D/TN)); softmax is one workgroup perHq*Srow.DynamicDispatchNode(recompute workgroups per execute); scratch is sized at build (S=max, ctx=Cmax) so buffers never move and bind groups stay valid.Constraints: fp32-only, batch=1, GQA,
is_causal=true,D%4==0invariants unchanged; the static / decode-only paths are unaffected (the q hook never fires without a resize).Co-authored-with: Claude Code.
Differential Revision: D109906097